from collections import defaultdict
import sys
import os
import numpy as np
from . import structurereader as sr

def write_structure(st, name=None, append=False, ext='', **options):
    """Write *st* to ``<name>.<ext>`` using the matching StructureWriter method.

    Parameters
    ----------
    st : Structure or Graph
        Object to write, recognised by its class name. A ``Graph`` is first
        converted with ``sr.conver_structure(st, "graph")``.
    name : str, optional
        Output path; defaults to the structure's ``basename``. When *ext* is
        empty, an extension on *name* is split off and used as the format,
        falling back to ``'pdb'`` when *name* has no extension.
    append : bool, optional
        Append to the output file instead of overwriting it.
    ext : str, optional
        Output format, with or without a leading dot (``'xyz'`` or
        ``'.xyz'``). The writer is looked up case-insensitively; the file
        name keeps the spelling given.
    **options
        Format-specific options, handed to the writer as a
        ``defaultdict(str)`` (missing keys read as ``''``).

    On an unsupported object type or format an error is printed to stderr
    and the process exits with status 1. The format is validated before the
    output file is opened, so no empty file is left behind.
    """
    sw = StructureWriter()
    kind = type(st).__name__
    if kind == 'Structure':
        sw.st = st
    elif kind == 'Graph':
        sw.st = sr.conver_structure(st, "graph")
    else:
        print("Error! You should either specify a structure object or a graph object",
              file=sys.stderr)
        sys.exit(1)

    if not name:
        name = sw.st.basename
    if ext:
        # Accept both 'xyz' and '.xyz'.
        ext = ext.lstrip('.')
    else:
        name, suffix = os.path.splitext(name)
        ext = suffix[1:] if suffix else 'pdb'

    # Resolve the writer first so an unsupported format never creates or
    # truncates the output file.
    writer = sw.write_func.get(ext.lower())
    if writer is None:
        print('Format {:s} is not supported yet'.format(ext), file=sys.stderr)
        sys.exit(1)

    sw.basename = name
    sw.options = defaultdict(str, options)
    filename = name + '.' + ext
    if append:
        print('Appending to file {:s}'.format(filename))
    else:
        print('Generating file {:s}'.format(filename))
    # The context manager closes the file even if the writer raises.
    with open(filename, 'a' if append else 'w') as handle:
        sw.file = handle
        writer()


    # for atom in copy.d

class StructureWriter:
    """Serialize a structure object to one of several text formats.

    The writer methods read these attributes, which the caller must set
    before invoking one of them (``write_structure`` does this):

    * ``st`` -- the structure to serialize;
    * ``file`` -- an open text stream with a ``write`` method;
    * ``basename`` -- used for titles and the default Gaussian checkpoint;
    * ``options`` -- mapping of format-specific options.

    ``write_func`` maps each supported extension to its writer method.
    """

    def __init__(self):
        self.write_func = {'res': self._write_res,
                           'pdb': self._write_pdb,
                           'mol2': self._write_mol2,
                           'gjf': self._write_gjf,
                           'gro': self._write_gro,
                           'cif': self._write_cif,
                           'xyz': self._write_xyz,
                           }

    def fd(self, input, f=float, d=0.0):
        """Fill default: return ``f(d)`` if *input* is ``''``, else ``f(input)``.

        The parameter keeps the name ``input`` (shadowing the builtin) so
        existing keyword callers keep working.
        """
        if input == '':
            return f(d)
        return f(input)

    def get_value(self, prop_name, fill=None):
        """Return attribute *prop_name* of ``self.st``, or a per-atom default.

        If ``self.st`` lacks the attribute, the result depends on *fill*:

        * ``None`` -> ``None``;
        * a non-string sequence with one entry per atom -> *fill* unchanged;
        * anything else (a string, a scalar, or a sequence of another length
          such as a 3-vector) -> a list repeating *fill* once per atom. The
          entries are the same object, so a mutable *fill* is shared.
        """
        try:
            return getattr(self.st, prop_name)
        except AttributeError:
            pass
        if fill is None:
            return None
        natom = len(self.st.cart_coord)
        is_per_atom = (not isinstance(fill, str)
                       and hasattr(fill, '__len__')
                       and len(fill) == natom)
        return fill if is_per_atom else natom * [fill]

    def _fill_from_elem(self, prop):
        """Set property *prop* to the element symbols unless every atom has a truthy value.

        This mutates ``self.st`` through its ``setter``.
        """
        if not all(self.st.getter(prop)):
            self.st.setter(prop, self.st.elem)

    def _write_res(self):
        """Write a SHELX ``.res`` file from fractional coordinates (``st.fcoord``)."""
        write = self.file.write
        write('TITL {:s}\n'.format(self.basename))
        # CELL card: wavelength (written as 0) followed by a b c alpha beta gamma.
        write(('CELL' + 7 * '{:>11.6f}' + '\n').format(0.0, *self.st.cell_param))
        write('LATT -1\n')
        # Unique elements in first-appearance order. A set() would make the
        # SFAC order (and so every atom's code) vary with string hashing.
        elements = list(dict.fromkeys(self.st.elem))
        sfac_code = {elem: k for k, elem in enumerate(elements, start=1)}
        write('SFAC ' + ' '.join(elements) + '\n')
        atom_fmt = '{:<6s}{:<3d}{:<14.8f}{:<14.8f}{:<14.8f}{:<11.5f}{:<10.5f}\n'
        for elem, frac in zip(self.st.elem, self.st.fcoord):
            # Trailing constants 1.0 / 0.0 fill the occupancy and U columns
            # (assumed SHELX atom-line layout).
            write(atom_fmt.format(*([elem, sfac_code[elem]] + list(frac) + [1.0, 0.0])))
        write('END\n')

    def _write_pdb(self):
        """Write a PDB file: optional CRYST1/SCALEn records, then ATOM records.

        Cell records are written only when ``st.period_flag == 1``; the space
        group is always given as P1. Missing atom names fall back to elements.
        """
        write = self.file.write
        write('REMARK    Generated by Masagna\n')
        if self.st.period_flag == 1:
            # list() so tuples and numpy arrays work as well as lists.
            write('CRYST1{:>9.3f}{:>9.3f}{:>9.3f}{:>7.2f}{:>7.2f}{:>7.2f} {:<11s}\n'
                  .format(*(list(self.st.cell_param) + ['P1'])))
            # SCALE = inverse(cell_vect) transposed; with lattice vectors as
            # rows of cell_vect (assumed) this maps Cartesian to fractional.
            scale = np.linalg.inv(np.asarray(self.st.cell_vect, dtype=float)).T
            for n, row in enumerate(scale.tolist(), start=1):
                # "SCALEn" in columns 1-6, Sn1..Sn3 in 11-40, Un in 46-55.
                write('SCALE{:<5d}{:>10.6f}{:>10.6f}{:>10.6f}{:5s}{:>10.5f}\n'
                      .format(*([n] + row + [' ', 0.0])))
        self._fill_from_elem('atomname')
        for a in self.st.atoms:
            # Columns 1-30: record, serial, name, resName, chainID, resSeq.
            head = '{:<6s}{:>5d} {:<4s} {:3s} {:1s}{:>4d}    '.format(
                'ATOM', a['sn'], a['atomname'], a['resname'], a['chainid'],
                self.fd(a['resid'], f=int, d=1))
            xyz = '{:>8.3f}{:>8.3f}{:>8.3f}'.format(*a['coord'])
            # Occupancy defaults to 1, B-factor to 0; columns 67-76 blank.
            occ_b = '{:>6.2f}{:>6.2f}{:10s}'.format(
                self.fd(a['occupancy'], d=1), self.fd(a['bfactor']), ' ')
            tail = '{:>2s}{:2s}\n'.format(a['elem'], a['formal_charge'])
            write(head + xyz + occ_b + tail)

    def _write_gro(self):
        """Write a GROMACS ``.gro`` file.

        Coordinates, and the box when ``st.cell_vect`` has three vectors, are
        divided by 10 (Angstrom -> nm, assuming Angstrom storage). Residue
        and atom numbers wrap modulo 100000 and names are cut to five
        characters, as GROMACS does, so the fixed-width columns never shift.
        """
        self._fill_from_elem('atomname')
        write = self.file.write
        write('gro file generate by masagna, t= 0.0\n')
        write('{:d}\n'.format(len(self.st.atoms)))
        for idx, a in enumerate(self.st.atoms):
            resid = self.fd(a['resid'], f=int, d=1) % 100000
            resname = self.fd(a['resname'], f=str, d='MOL')
            line = '{:5d}{:<5.5s}{:>5.5s}{:5d}'.format(
                resid, resname, a['atomname'], a['sn'] % 100000)
            line += '{:8.3f}{:8.3f}{:8.3f}'.format(*[x / 10 for x in self.st.coord[idx]])
            if a['velocity'] == '':
                line += '\n'
            else:
                line += '{:8.4f}{:8.4f}{:8.4f}\n'.format(*a['velocity'])
            write(line)
        if len(self.st.cell_vect) == 3:
            v1, v2, v3 = self.st.cell_vect
            # .gro box order: v1(x) v2(y) v3(z) v1(y) v1(z) v2(x) v2(z) v3(x) v3(y).
            box = [v1[0], v2[1], v3[2], v1[1], v1[2], v2[0], v2[2], v3[0], v3[1]]
            write(' '.join('{:.5f}'.format(x / 10) for x in box) + '\n')

    def _write_mol2(self):
        """Write a Tripos MOL2 file with MOLECULE and ATOM sections (no bonds).

        Missing atom names and atom types fall back to element symbols.
        Substructure id, substructure name and partial charge are appended
        while each is present. The charge-type line is ``USER_CHARGES`` if
        any charge was written and ``NO_CHARGES`` otherwise.
        """
        self._fill_from_elem('atomname')
        self._fill_from_elem('atomtype')
        # Each padded field ends in a space, so over-long values can no
        # longer run into the next whitespace-separated token.
        atom_lines = []
        has_charge = False
        for a in self.st.atoms:
            line = '{:<5d} {:<5s} '.format(a['sn'], a['atomname'])
            line += '{:<11.5f} {:<11.5f} {:<11.5f} '.format(*a['coord'])
            line += '{:<5s} '.format(a['atomtype'])
            if a['resid'] != '':
                line += '{:<5d} '.format(a['resid'])
                if a['resname'] != '':
                    line += '{:<5s} '.format(a['resname'])
                    if a['charge'] != '':
                        line += '{:<.6f}'.format(a['charge'])
                        has_charge = True
            atom_lines.append(line + '\n')
        write = self.file.write
        write('@<TRIPOS>MOLECULE\n')
        write('{:s}\n'.format(self.st.basename))
        write('{:d} 0\n'.format(len(self.st.coord)))
        write('SMALL\n')
        write('USER_CHARGES\n' if has_charge else 'NO_CHARGES\n')
        write('@<TRIPOS>ATOM\n')
        write(''.join(atom_lines))

    def _write_gjf(self):
        """Write a Gaussian input (``.gjf``) file.

        Recognised ``self.options`` keys (others are ignored): ``nproc``,
        ``cpu``, ``mem``, ``chk``, ``oldchk``, ``keywords``, ``charge``,
        ``spin``, ``extra`` and ``link``. Scalar values may be strings or
        numbers. ``link`` is a list of dicts, one per ``--Link1--`` step; a
        step is written only if it has ``keywords`` and may also give
        ``chk`` and ``charge`` + ``spin``.
        """
        param = {'nproc': '8', 'charge': '0', 'spin': '1', 'mem': '4GB',
                 'extra': '', 'cpu': '', 'oldchk': '',
                 'chk': '{:s}.chk'.format(self.basename), 'link': [],
                 # PBE0 hybrid; Gaussian spells it PBE1PBE.
                 'keywords': 'pbe1pbe def2svp em(gd3bj)'}
        param.update({k: v for k, v in self.options.items() if k in param})
        write = self.file.write

        def write_resources():
            # An explicit %cpu core list replaces %nprocshared.
            if param['cpu']:
                write('%cpu={}\n'.format(param['cpu']))
            else:
                write('%nprocshared={}\n'.format(param['nproc']))
            write('%mem={}\n'.format(param['mem']))

        def write_title_and_multiplicity(charge, spin):
            write('\n{} generated by CoordMagic\n\n'.format(self.basename))
            write('{} {}\n'.format(charge, spin))

        write('%chk={}\n'.format(param['chk']))
        if param['oldchk'] and param['oldchk'] != param['chk']:
            write('%oldchk={}\n'.format(param['oldchk']))
        write_resources()
        write('#p {}\n'.format(param['keywords']))
        write_title_and_multiplicity(param['charge'], param['spin'])
        for a in self.st.atoms:
            write('{:<6s}{:<12.5f}{:<12.5f}{:<12.5f}\n'.format(a['elem'], *a['coord']))
        if param['extra']:
            extra = str(param['extra'])
            write('\n')
            write(extra)
            # Terminate the section so a following --Link1-- gets its blank line.
            if not extra.endswith('\n'):
                write('\n')
        for link in param['link']:
            if 'keywords' not in link:
                continue
            write('\n--Link1--\n')
            # Link 0 settings are repeated in every job step.
            write_resources()
            if 'chk' in link and link['chk'] != param['chk']:
                # New checkpoint for this step, seeded from the main one.
                write('%chk={}\n'.format(link['chk']))
                write('%oldchk={}\n'.format(param['chk']))
            else:
                # Same or unspecified checkpoint: still name it for this step.
                write('%chk={}\n'.format(param['chk']))
            write('#p {}\n'.format(link['keywords']))
            if 'charge' in link and 'spin' in link:
                write_title_and_multiplicity(link['charge'], link['spin'])
        write('\n\n\n')

    def _write_xyz(self):
        """Write an XYZ file: atom count, one comment line, then element and x y z.

        The comment is ``options['comment']`` ('' if absent); line breaks in
        it are replaced by spaces because the format has exactly one comment
        line.
        """
        comment = ' '.join(str(self.options.get('comment', '')).splitlines())
        write = self.file.write
        write('{:d}\n'.format(len(self.st.atoms)))
        write('{:s}\n'.format(comment))
        for a in self.st.atoms:
            write('{:<6s}{:>12.6f}{:>12.6f}{:>12.6f}\n'.format(a['elem'], *a['coord']))

    def _write_cif(self):
        """Write a minimal CIF: cell parameters plus an atom_site loop of fractional coordinates.

        Missing atom names fall back to element symbols.
        """
        write = self.file.write
        write('data_' + self.st.basename + '\n')
        cell_tags = ('_cell_length_a', '_cell_length_b', '_cell_length_c',
                     '_cell_angle_alpha', '_cell_angle_beta', '_cell_angle_gamma')
        for k, tag in enumerate(cell_tags):
            write('{:35s}{:.6f}\n'.format(tag, self.st.cell_param[k]))
        write('loop_\n'
              '_atom_site_label\n'
              '_atom_site_type_symbol\n'
              '_atom_site_fract_x\n'
              '_atom_site_fract_y\n'
              '_atom_site_fract_z\n')
        self._fill_from_elem('atomname')
        for a in self.st.atoms:
            # Label and symbol each take 7 columns, the last always a space,
            # so long labels stay separate tokens.
            write('{:<6s} {:<6s} {:14.8f}{:14.8f}{:14.8f}\n'
                  .format(a['atomname'], a['elem'], *a['fcoord']))
